from custom_test_runner import TOTAL_COST, COMPLETION_TOKENS, PROMPT_TOKENS
from rasa.dialogue_understanding.generator import SingleStepLLMCommandGenerator
import litellm
from typing import Any, List, Optional, Text
from functools import partial

import structlog
from rasa.dialogue_understanding.commands import (
    Command
)
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.shared.core.flows import FlowsList
from rasa.shared.core.trackers import DialogueStateTracker

from rasa.shared.nlu.training_data.message import Message
from rasa.shared.utils.llm import (
    DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
    DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
)
# --- Configuration constants -------------------------------------------------
# NOTE(review): the configuration constants below are not referenced by the
# code visible in this file; presumably they are kept for importers of this
# module — confirm before removing any of them.

# Presumably the file name of the Jinja2 command prompt template.
COMMAND_PROMPT_FILE_NAME = "command_prompt.jinja2"

# Default LLM settings (provider, model, sampling temperature, generation
# limit and request timeout).
DEFAULT_LLM_CONFIG = dict(
    provider="openai",
    model=DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
    temperature=0.0,
    max_tokens=DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
    timeout=20,
)

# Configuration key names.
LLM_CONFIG_KEY = "llm"
USER_INPUT_CONFIG_KEY = "user_input"
FLOW_RETRIEVAL_KEY = "flow_retrieval"
FLOW_RETRIEVAL_ACTIVE_KEY = "active"

# Module-level structured logger.
structlogger = structlog.get_logger()


@DefaultV1Recipe.register(
    [
        DefaultV1Recipe.ComponentType.COMMAND_GENERATOR,
    ],
    is_trainable=True,
)
class CustomLLMCommandGenerator(SingleStepLLMCommandGenerator):
    """Command generator that records LLM token usage and cost on the message.

    Command prediction is delegated unchanged to the parent class. Before
    delegating, a litellm success callback bound to the current message is
    installed; that callback copies the completion's token counts and the
    response cost onto the message.
    """

    async def predict_commands(
        self,
        message: Message,
        flows: FlowsList,
        tracker: Optional[DialogueStateTracker] = None,
        **kwargs: Any,
    ) -> List[Command]:
        """Predict commands using the LLM.

        Args:
            message: The message from the user.
            flows: The flows available to the user.
            tracker: The tracker containing the current state of the conversation.
            **kwargs: Extra keyword arguments, forwarded to the parent
                implementation.

        Returns:
            The commands generated by the llm.
        """
        # NOTE(review): this assignment replaces litellm's module-wide success
        # callback list on every call. Any callback registered elsewhere is
        # dropped, and overlapping predictions may have their usage attributed
        # to whichever message was bound last — confirm this is acceptable for
        # the deployment before relying on the recorded numbers.
        litellm.success_callback = [
            partial(self.track_cost_callback, message=message)
        ]

        # Forward **kwargs so options supplied by the caller reach the parent
        # implementation instead of being silently discarded.
        return await super().predict_commands(message, flows, tracker, **kwargs)

    @staticmethod
    async def track_cost_callback(
        kwargs: dict,
        completion_response: Any,
        start_time: Any,
        end_time: Any,
        message: Message,
    ) -> None:
        """Copy token usage and cost of a finished completion onto ``message``.

        Intended to be registered as a litellm success callback with
        ``message`` pre-bound via ``functools.partial``.

        Args:
            kwargs: Keyword arguments of the completion call; the cost is read
                from its ``"response_cost"`` entry.
            completion_response: Response of the completion; token counts are
                read from its ``usage`` attribute.
            start_time: Start time of the completion call (unused).
            end_time: End time of the completion call (unused).
            message: The message that receives the usage values.
        """
        try:
            usage = completion_response.usage
            message.set(
                COMPLETION_TOKENS,
                usage.completion_tokens,
                add_to_output=True,
            )
            message.set(
                PROMPT_TOKENS,
                usage.prompt_tokens,
                add_to_output=True,
            )
            message.set(
                TOTAL_COST,
                kwargs["response_cost"],
                add_to_output=True,
            )
        except Exception as exc:
            # Cost tracking is best-effort and must never break prediction,
            # but failures are logged rather than hidden. Catching Exception
            # (not a bare except) lets task cancellation and interpreter
            # shutdown signals propagate.
            structlogger.warning(
                "custom_llm_command_generator.track_cost_callback.failed",
                error=str(exc),
            )
